package cn.zhxu.toys.client;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.net.ConnectException;
import java.net.URI;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.AsyncListenableTaskExecutor;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.AsyncClientHttpRequest;
import org.springframework.http.client.AsyncClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.*;
import org.springframework.web.client.*;


/**
 * {@link AsyncRestTemplate} that re-issues a request when an attempt fails with a
 * {@link ConnectException}, and that can hand back a response body as a raw {@link InputStream}.
 *
 * <p>A request is attempted once and then retried at most {@link #getMaxRetryTimes()} more times,
 * and only while the failure is a {@code ConnectException}; any other failure completes the
 * returned future immediately. Retries are driven from the moment the request is issued, so they
 * happen regardless of whether (or how many) callbacks are registered, and {@code get()} waits for
 * the final outcome.
 *
 * <p>When the requested response type is {@code InputStream} (or a subtype), the response is left
 * open after extraction so the body stream can be read; the caller is responsible for closing it.
 */
public class RetryRestTemplate extends AsyncRestTemplate {

    static final Logger log = LoggerFactory.getLogger(RetryRestTemplate.class);

    /** Maximum number of extra attempts after the first one; read from I/O callback threads. */
    private volatile int maxRetryTimes = 5;

    public RetryRestTemplate() {
    }

    /**
     * @param maxRetryTimes maximum number of retries after the first attempt
     */
    public RetryRestTemplate(int maxRetryTimes) {
        this.maxRetryTimes = maxRetryTimes;
    }

    public RetryRestTemplate(AsyncListenableTaskExecutor taskExecutor) {
        super(taskExecutor);
    }

    public RetryRestTemplate(AsyncClientHttpRequestFactory asyncRequestFactory) {
        super(asyncRequestFactory);
    }

    public RetryRestTemplate(AsyncClientHttpRequestFactory asyncRequestFactory, ClientHttpRequestFactory syncRequestFactory) {
        super(asyncRequestFactory, syncRequestFactory);
    }

    public RetryRestTemplate(AsyncClientHttpRequestFactory requestFactory, RestTemplate restTemplate) {
        super(requestFactory, restTemplate);
    }


    /**
     * Executes a single attempt whose response is NOT closed after extraction, so that an
     * extractor may return the live body stream.
     *
     * @throws ResourceAccessException if creating or preparing the request fails with an I/O error
     */
    protected <T> ListenableFuture<T> doRawExecute(URI url, HttpMethod method, AsyncRequestCallback requestCallback,
                                                ResponseExtractor<T> responseExtractor) throws RestClientException {
        Assert.notNull(url, "'url' must not be null");
        Assert.notNull(method, "'method' must not be null");
        try {
            AsyncClientHttpRequest request = createAsyncRequest(url, method);
            if (requestCallback != null) {
                requestCallback.doWithRequest(request);
            }
            ListenableFuture<ClientHttpResponse> responseFuture = request.executeAsync();
            return new RawResponseExtractorFuture<T>(method, url, responseFuture, responseExtractor);
        }
        catch (IOException ex) {
            throw new ResourceAccessException("I/O error on " + method.name() +
                " request for \"" + url + "\":" + ex.getMessage(), ex);
        }
    }


    /**
     * Issues the first attempt and wraps it in a future that transparently retries on
     * {@link ConnectException}. A synchronous failure of the first attempt propagates to the caller.
     */
    @Override
    protected <T> ListenableFuture<T> doExecute(URI url, HttpMethod method, AsyncRequestCallback requestCallback,
                                                ResponseExtractor<T> responseExtractor) throws RestClientException {
        ListenableFuture<T> firstAttempt = executeOnce(url, method, requestCallback, responseExtractor);
        return new DelegateListenableFuture<>(firstAttempt, url, method, requestCallback, responseExtractor);
    }

    /**
     * Issues exactly one attempt with no retry wrapping. Raw extractors take the non-closing path;
     * everything else goes through the standard {@link AsyncRestTemplate} execution.
     */
    private <T> ListenableFuture<T> executeOnce(URI url, HttpMethod method, AsyncRequestCallback requestCallback,
                                                ResponseExtractor<T> responseExtractor) throws RestClientException {
        if (responseExtractor instanceof RawResponseExtractor) {
            return doRawExecute(url, method, requestCallback, responseExtractor);
        }
        return super.doExecute(url, method, requestCallback, responseExtractor);
    }

    /**
     * Future returned to callers: follows the current attempt and, when it fails with a
     * {@link ConnectException}, issues a new attempt until {@code maxRetryTimes} is exhausted.
     * The final outcome is published once through an internal {@link SettableListenableFuture},
     * so every registered callback and every {@code get()} sees the same single result.
     */
    class DelegateListenableFuture<T> implements ListenableFuture<T> {

        private final URI url;
        private final HttpMethod method;
        private final AsyncRequestCallback requestCallback;
        private final ResponseExtractor<T> responseExtractor;

        /** Outcome observed by callers; completed exactly once. */
        private final SettableListenableFuture<T> result = new SettableListenableFuture<>();

        /** Attempt currently in flight; replaced on every retry from a callback thread. */
        private volatile ListenableFuture<T> delegate;

        /**
         * Failed attempts so far. Only touched from attempt callbacks; each retry is issued from the
         * previous attempt's callback, so these accesses are chained one after another.
         */
        private int failureTimes;

        DelegateListenableFuture(ListenableFuture<T> delegate, URI url, HttpMethod method,
                                 AsyncRequestCallback requestCallback, ResponseExtractor<T> responseExtractor) {
            this.url = url;
            this.method = method;
            this.requestCallback = requestCallback;
            this.responseExtractor = responseExtractor;
            // All fields are assigned before the first callback can observe this instance.
            attach(delegate);
        }

        /** Makes {@code attempt} the current attempt and forwards its outcome. */
        private void attach(ListenableFuture<T> attempt) {
            this.delegate = attempt;
            if (result.isCancelled()) {
                // cancel() may have raced with a retry; do not let the new attempt linger.
                attempt.cancel(true);
                return;
            }
            attempt.addCallback(new ListenableFutureCallback<T>() {
                @Override
                public void onSuccess(T value) {
                    result.set(value);
                }

                @Override
                public void onFailure(Throwable ex) {
                    handleFailure(ex);
                }
            });
        }

        /** Retries on connection failures while the budget allows; otherwise fails the result. */
        private void handleFailure(Throwable ex) {
            if (result.isCancelled()) {
                return;
            }
            if (!(ex instanceof ConnectException) || failureTimes >= maxRetryTimes) {
                result.setException(ex);
                return;
            }
            failureTimes++;
            log.info("重试{}次 [{}:{}]", failureTimes, method, url);
            ListenableFuture<T> retry;
            try {
                retry = executeOnce(url, method, requestCallback, responseExtractor);
            }
            catch (RuntimeException retryEx) {
                // Escaping here would leave the result incomplete forever.
                result.setException(retryEx);
                return;
            }
            attach(retry);
        }

        @Override
        public void addCallback(ListenableFutureCallback<? super T> callback) {
            result.addCallback(callback);
        }

        @Override
        public void addCallback(SuccessCallback<? super T> successCallback, FailureCallback failureCallback) {
            result.addCallback(successCallback, failureCallback);
        }

        /** Cancels the result and the attempt in flight; no further retries are issued afterwards. */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            boolean cancelled = result.cancel(mayInterruptIfRunning);
            if (cancelled) {
                delegate.cancel(mayInterruptIfRunning);
            }
            return cancelled;
        }

        @Override
        public boolean isCancelled() {
            return result.isCancelled();
        }

        @Override
        public boolean isDone() {
            return result.isDone();
        }

        @Override
        public T get() throws InterruptedException, ExecutionException {
            return result.get();
        }

        @Override
        public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            return result.get(timeout, unit);
        }
    }


    /**
     * Returns a raw extractor yielding the live {@link InputStream} body when {@code responseType}
     * is an {@code InputStream} type; otherwise defers to the standard extractor.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected <T> ResponseExtractor<ResponseEntity<T>> responseEntityExtractor(Type responseType) {
        // Supports returning the InputStream body directly.
        if (isRawResponseType(responseType)) {
            return new RawResponseExtractor<ResponseEntity<T>>() {
                @Override
                public ResponseEntity<T> extractData(ClientHttpResponse response) throws IOException {
                    // T is an InputStream type here, as checked by isRawResponseType.
                    return new ResponseEntity<T>((T) response.getBody(), response.getHeaders(), response.getStatusCode());
                }
            };
        }
        return super.responseEntityExtractor(responseType);
    }

    /** Marker for extractors whose response must not be closed after extraction. */
    interface RawResponseExtractor<T> extends ResponseExtractor<T> { }

    /**
     * Adapts a response future like {@link AsyncRestTemplate}'s own adapter, except that a
     * successfully extracted response is left open for the caller to consume.
     */
    private class RawResponseExtractorFuture<T> extends ListenableFutureAdapter<T, ClientHttpResponse> {

        private final HttpMethod method;

        private final URI url;

        private final ResponseExtractor<T> responseExtractor;

        public RawResponseExtractorFuture(HttpMethod method, URI url,
                                       ListenableFuture<ClientHttpResponse> clientHttpResponseFuture, ResponseExtractor<T> responseExtractor) {
            super(clientHttpResponseFuture);
            this.method = method;
            this.url = url;
            this.responseExtractor = responseExtractor;
        }

        @Override
        protected final T adapt(ClientHttpResponse response) throws ExecutionException {
            try {
                if (!getErrorHandler().hasError(response)) {
                    logResponseStatus(this.method, this.url, response);
                }
                else {
                    handleResponseError(this.method, this.url, response);
                }
                // On success the response is deliberately not closed: the caller reads the body.
                return convertResponse(response);
            }
            catch (Throwable ex) {
                // No one else will ever receive this response, so release it here.
                try {
                    response.close();
                }
                catch (RuntimeException closeEx) {
                    ex.addSuppressed(closeEx);
                }
                throw new ExecutionException(ex);
            }
        }

        protected T convertResponse(ClientHttpResponse response) throws IOException {
            return (this.responseExtractor != null ? this.responseExtractor.extractData(response) : null);
        }
    }


    /** True when {@code responseType} is a class assignable to {@link InputStream}. */
    private boolean isRawResponseType(Type responseType) {
        return responseType instanceof Class<?> && InputStream.class.isAssignableFrom((Class<?>) responseType);
    }


    /** @return maximum number of retries after the first attempt */
    public int getMaxRetryTimes() {
        return maxRetryTimes;
    }

    /** @param maxRetryTimes maximum number of retries after the first attempt */
    public void setMaxRetryTimes(int maxRetryTimes) {
        this.maxRetryTimes = maxRetryTimes;
    }

    private void logResponseStatus(HttpMethod method, URI url, ClientHttpResponse response) {
        if (logger.isDebugEnabled()) {
            try {
                logger.debug("Async " + method.name() + " request for \"" + url + "\" resulted in " +
                    response.getRawStatusCode() + " (" + response.getStatusText() + ")");
            }
            catch (IOException ignored) {
                // Status line is only needed for this log message; skip it if unreadable.
            }
        }
    }

    private void handleResponseError(HttpMethod method, URI url, ClientHttpResponse response) throws IOException {
        if (logger.isWarnEnabled()) {
            try {
                logger.warn("Async " + method.name() + " request for \"" + url + "\" resulted in " +
                    response.getRawStatusCode() + " (" + response.getStatusText() + "); invoking error handler");
            }
            catch (IOException ignored) {
                // Status line is only needed for this log message; the error handler still runs.
            }
        }
        getErrorHandler().handleError(response);
    }

}